import cdtools
from cdtools.tools import plotting as p
from scipy.io import loadmat, savemat
from helper_functions import fourier_pad
import numpy as np
import torch as t
import argparse

# Run on the GPU when one is available; otherwise fall back to the CPU so
# that the script still runs (more slowly) on machines without CUDA
device = 'cuda' if t.cuda.is_available() else 'cpu'

# Number of pixels of padding added to each edge of every diffraction
# pattern. It is defined once here because the name of the ptychography
# calibration file loaded further down is built from the same value
to_pad = 200

rpi_filename = 'data/single_shot_rpi.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(rpi_filename,
                                                    cut_zeros=False)

# We add 10 to the dataset, to enable our algorithm to find the
# low-intensity flat-field which otherwise goes negative
dataset.patterns = t.clamp(dataset.patterns + 10, min=0)

# We pad everything here so that we can use the probe recovered via the
# super-resolution ptychography. Using the padding here doesn't increase
# the resolution of the resulting RPI, but it does prevent aliasing from
# light scattered off of the edge of the detector and matches the forward
# model between the ptychography and RPI reconstructions
#
# A 4-element pad specification extends the last two dimensions of the
# tensor by to_pad on each side
pad_widths = (to_pad, to_pad, to_pad, to_pad)
dataset.patterns = t.nn.functional.pad(dataset.patterns, pad_widths)
# The new border of the mask is filled with 0
# NOTE(review): presumably 0 flags the padded pixels as unmeasured - confirm
dataset.mask = t.nn.functional.pad(dataset.mask, pad_widths, value=0)

# The starting probe and detector background both come from an earlier
# ptychography calibration, saved with the same padding as is used here
init_file = f'results/single_shot_ptycho_full_polished_{to_pad}pad.mat'
ptycho_results = loadmat(init_file)
background = t.as_tensor(ptycho_results['background'])
# NOTE(review): fourier_pad is a project helper; presumably it enlarges the
# probe by to_pad so it matches the padded patterns - confirm
probe = fourier_pad(t.as_tensor(ptycho_results['probe']), to_pad)


# Side length, in pixels, of the recovered RPI image. The scalar value is
# embedded in the output filename before being expanded to the two-element
# [resolution, resolution] list that is handed to the model
resolution = 650
savefile = f'results/single_shot_rpi_2modes_{resolution}.mat'
resolution = [resolution] * 2

# The regularization is an L2 regularizer that empirically helps accelerate
# convergence. By increasing the regularization on the second mode, we
# drive the object reconstruction into the top mode, which is the result.
# The same weights are used for both the Adam and the L-BFGS stages
regularization_factor = [30, 300]

# The results saved at the end are read from the model built for the final
# frame, so with no frames there is nothing to save. Fail here with a clear
# message rather than with a NameError after the loop
if len(dataset) == 0:
    raise ValueError(f'No diffraction patterns found in {rpi_filename}')

# Now we reconstruct each pattern in the dataset
objs = []
for frame in range(len(dataset)):
    print('Now working on frame', frame+1)

    # We create a fresh RPI model from the dataset for every frame
    # Note that we explicitly use two incoherent object modes
    model = cdtools.models.RPI.from_dataset(dataset, probe, resolution,
                                            background=background, n_modes=2,
                                            initialization='uniform',
                                            weight_matrix=False,
                                            probe_threshold=0.01,
                                            mask=dataset.mask)

    # We restrict the object to the region where the probe has support
    model.obj.data = model.obj.data * model.obj_support

    # We move everything to the right device
    model.to(device=device)
    dataset.get_as(device=device)

    # A few iterations of Adam works to initialize the reconstruction.
    # subset=[frame] limits the optimization to the current pattern
    for loss in model.Adam_optimize(50, dataset, subset=[frame], lr=0.04,
                                    regularization_factor=regularization_factor):
        print(model.report())

    # And we converge to the final result with L-BFGS.
    last_loss = None
    for loss in model.LBFGS_optimize(25, dataset, subset=[frame], lr=0.4,
                                     regularization_factor=regularization_factor,
                                     line_search_fn='strong_wolfe'):
        print(model.report())
        # Stop early once the reported loss is exactly the same as on the
        # previous iteration
        if loss == last_loss:
            break
        last_loss = loss

    # Only the first (top) object mode is kept as the result for this frame
    objs.append(model.obj[0].detach().cpu().numpy())


objs = np.array(objs)
# The object basis is taken from the model of the last frame reconstructed
results = {'basis': model.obj_basis.cpu().numpy(),
           'objs': objs}

savemat(savefile, results)

